﻿#pragma kernel BitonicSort

// Kernel parameters. NOTE(review): presumably assigned by the host before each dispatch - confirm against the caller.

// Number of elements to sort; any pair whose right-hand index is >= this value is skipped.
const uint numEntries;
// Threads per group of pairs. Used both as a divisor and as a bit mask (groupWidth - 1),
// so the kernel appears to assume a power of two - TODO confirm on the host side.
const uint groupWidth;
// Group extent used to derive the left index ((groupHeight + 1) per group) and the partner offset.
const uint groupHeight;
// 0 selects the per-thread offset (groupHeight - 2 * hIndex); any other value selects the
// fixed offset (groupHeight + 1) / 2.
const uint stepIndex;

// Payload that is swapped alongside Values; it is never compared by the kernel.
RWStructuredBuffer<float> Keys;
// Data the ordering is based on: a pair is swapped when the left value is greater than the right.
RWStructuredBuffer<float> Values;

// Compare-and-swap kernel for a bitonic sort.
// Every thread owns one (lower, upper) pair of elements. If the entry of Values at the
// lower index is greater than the one at the upper index, the two positions are exchanged
// in both Values and Keys, so the smaller value ends up at the lower index.
//   id.x - global thread index, used as the index of the pair to process.
[numthreads(128,1,1)]
void BitonicSort(uint3 id : SV_DispatchThreadID)
{
    uint pairIndex = id.x;

    // Split the pair index into its position inside a group and the group it belongs to
    uint offsetInGroup = pairIndex & (groupWidth - 1);
    uint groupIndex = pairIndex / groupWidth;

    uint lower = offsetInGroup + (groupHeight + 1) * groupIndex;

    // Distance from the lower element to its partner: depends on the thread's position
    // within the group when stepIndex is 0, otherwise it is the same for every thread
    uint partnerOffset;
    if (stepIndex == 0)
    {
        partnerOffset = groupHeight - 2 * offsetInGroup;
    }
    else
    {
        partnerOffset = (groupHeight + 1) / 2;
    }

    uint upper = lower + partnerOffset;

    // Pairs reaching past the end of the data are ignored (for non-power of 2 input sizes)
    if (upper < numEntries)
    {
        float lowerValue = Values[lower];
        float upperValue = Values[upper];

        // Only a descending pair needs to be exchanged
        if (lowerValue > upperValue)
        {
            float lowerKey = Keys[lower];
            float upperKey = Keys[upper];

            Values[lower] = upperValue;
            Values[upper] = lowerValue;

            Keys[lower] = upperKey;
            Keys[upper] = lowerKey;
        }
    }
}